
Conversation

@fegin (Contributor) commented Feb 5, 2026

Stack from ghstack (oldest at bottom):

Why this PR

Understanding different DTensor shardings is important, especially when working on the full-DTensor project. This tool is created for debugging purposes.

What this PR does

This PR adds a sharding debug tool that captures and visualizes DTensor sharding information during training. When enabled via debug.log_sharding_info=True, it registers forward and backward hooks on all modules to record tensor placements, device mesh info, and shapes for one forward/backward pass. The tool outputs both a formatted ASCII text file and an interactive HTML visualization.
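
For readers who want a mental model of the capture mechanism, here is a minimal, hypothetical sketch of hook-based recording. The `ShardingRecord` class, helper names, and exact fields are illustrative, not the PR's actual code:

```python
# Minimal sketch (not the PR's implementation) of recording DTensor sharding
# via module hooks for a single forward/backward pass.
from dataclasses import dataclass, field

import torch
import torch.nn as nn
from torch.distributed.tensor import DTensor


@dataclass
class ShardingRecord:  # hypothetical record structure
    module: str
    tensors: dict = field(default_factory=dict)  # name -> (placements, mesh_shape, shape)


def _describe(t):
    """Return (placements, mesh shape, tensor shape) for DTensors; shape only otherwise."""
    if isinstance(t, DTensor):
        return (tuple(str(p) for p in t.placements), tuple(t.device_mesh.shape), tuple(t.shape))
    if isinstance(t, torch.Tensor):
        return (None, None, tuple(t.shape))
    return None


def attach_sharding_hooks(model: nn.Module, records: list):
    """Register forward and full-backward hooks on every submodule."""
    handles = []
    for name, mod in model.named_modules():
        def fwd_hook(m, inputs, output, name=name):
            rec = ShardingRecord(module=name)
            for i, x in enumerate(inputs):
                rec.tensors[f"input[{i}]"] = _describe(x)
            rec.tensors["output"] = _describe(output)
            for pname, p in m.named_parameters(recurse=False):
                rec.tensors[f"param.{pname}"] = _describe(p)
            records.append(rec)

        def bwd_hook(m, grad_inputs, grad_outputs, name=name):
            rec = ShardingRecord(module=f"{name}.backward")
            for i, g in enumerate(grad_outputs):
                rec.tensors[f"grad_output[{i}]"] = _describe(g)
            records.append(rec)

        handles.append(mod.register_forward_hook(fwd_hook))
        handles.append(mod.register_full_backward_hook(bwd_hook))
    return handles  # call h.remove() on each handle after one pass
```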

Limitation:

This tool can only track 1) module inputs/outputs and their gradients, and 2) module states and their gradients. Any activation generated by an op that is not a module cannot be tracked. We would have to use TorchFunctionMode or TorchDispatchMode to do that.
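
For comparison, a rough sketch of what op-level tracking with TorchDispatchMode could look like. This is illustrative only; the mode class and field names are made up, and the exact interaction with DTensor dispatch may vary by PyTorch version:

```python
# Illustrative sketch: a TorchDispatchMode that records DTensor placements for
# every aten op, including activations produced outside of nn.Module boundaries.
from torch.distributed.tensor import DTensor
from torch.utils._python_dispatch import TorchDispatchMode


class ShardingTraceMode(TorchDispatchMode):
    def __init__(self):
        super().__init__()
        self.trace = []  # (op, input placements, output placements)

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        out = func(*args, **(kwargs or {}))
        in_pl = [tuple(str(p) for p in a.placements) for a in args if isinstance(a, DTensor)]
        out_pl = tuple(str(p) for p in out.placements) if isinstance(out, DTensor) else None
        if in_pl or out_pl:
            self.trace.append((str(func), in_pl, out_pl))
        return out


# Usage sketch (assuming the model's inputs/params are already DTensors):
# with ShardingTraceMode() as mode:
#     model(inputs).sum().backward()
# for op, in_pl, out_pl in mode.trace:
#     print(op, in_pl, "->", out_pl)
```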

For Reviewers:
The UX functions (ASCII and HTML) are completely generated by Claude. I'm not an experienced frontend developer and didn't code review the HTML file in much detail.

```
NGPU=8 COMM_MODE=fake_backend CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --parallelism.tensor_parallel_degree=8 --debug.log_sharding_info
```

fegin added a commit that referenced this pull request Feb 5, 2026 (ghstack-source-id: bcc7749, Pull-Request: #2328)

@meta-cla bot added the CLA Signed label Feb 5, 2026
fegin added a commit that referenced this pull request Feb 6, 2026 (ghstack-source-id: 5354c4e, Pull-Request: #2328)
@wwwjn (Contributor) left a comment
This is generally useful for all DTensor developers, not only titan developers (and we would expect infra developers to be interested in this tool, rather than model researchers). Should we put it in PyTorch, like flight_recorder?

@fegin (Contributor, Author) commented Feb 6, 2026

> This is generally useful for all DTensor developers, not only titan developers

This is true.

> (and we would expect infra developers to be interested in this tool, rather than model researchers)

Most of our users also care about scaling, so this is useful for them too.

> Should we put it in PyTorch, like flight_recorder?

The reason I haven't put it in PyTorch is that this tool is not general enough yet. It is likely to be polished continuously once we start using it. After it is mature enough, we can upstream it to PyTorch.

@fegin requested a review from wwwjn on February 6, 2026 04:45
@wconstab (Contributor) left a comment

I'm wondering about the best way to land this. It seems nice, but it's also a big pile of code.

  • It's not obvious that this is torchtitan-specific. Should some/all of it go into torch itself?

Is there a nice way to decouple a 'core' piece that lands in torch? Perhaps:

  • a context manager that itself produces a well-defined data structure
  • a clearly separate plugin that takes the data and renders it (a rough sketch of this shape follows below)

This way, you can keep the HTML stuff out of tree.
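
As a rough illustration of that decoupling (all names here are hypothetical, not a proposed API): the core piece yields a plain data structure, and any renderer, whether in-tree or out-of-tree, consumes it.

```python
# Hypothetical shape of a capture core that produces only data; ASCII/HTML/JSON
# renderers would be separate consumers of the returned records.
import contextlib
from dataclasses import dataclass, field


@dataclass
class ModuleSharding:                              # illustrative data structure
    fqn: str                                       # module fully-qualified name
    inputs: list = field(default_factory=list)     # e.g. [{"placements": [...], "shape": [...]}]
    outputs: list = field(default_factory=list)
    params: dict = field(default_factory=dict)


@contextlib.contextmanager
def capture_sharding(model):
    records = []
    handles = _install_hooks(model, records)  # hypothetical helper, e.g. the hook sketch above
    try:
        yield records                         # run one forward/backward pass inside the block
    finally:
        for h in handles:
            h.remove()
```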

@tianyu-l (Contributor) commented Feb 7, 2026

I think it partially addresses the pain point of sharding info not being an integral part of nn.Module, in the way that we explicitly insert hooks to record what happens, somewhat similar to:

  • CommDebugMode integration #480 made by @anshul-si earlier
  • DebugMode, which is more about aten-level recording

Can we merge/iterate with one of those tools in pytorch?

@fegin (Contributor, Author) commented Feb 10, 2026

> CommDebugMode integration #480 made by @anshul-si earlier
> DebugMode, which is more about aten-level recording

@tianyu-l Correct, I'm aware of CommDebugMode, but as you mentioned, it is more about the aten level, which makes it hard to understand the sharding of a module. I thought about merging the two, but currently I have no clear idea how to make them coexist in CommDebugMode without making the UX too bad or the outputs too verbose. But I need it for debugging purposes, so I would prefer landing it in TorchTitan, iterating on the UX, and upstreaming it later when it is more general. I can also land it in the full_dtensor branch first, since it is not general enough and we have a separate branch for development purposes.

cc @wconstab @wwwjn The same answer applies to your questions as well.

@wwwjn (Contributor) commented Feb 10, 2026

> @tianyu-l Correct, I'm aware of CommDebugMode, but as you mentioned, it is more about the aten level, which makes it hard to understand the sharding of a module. I thought about merging the two, but currently I have no clear idea how to make them coexist in CommDebugMode without making the UX too bad or the outputs too verbose. But I need it for debugging purposes, so I would prefer landing it in TorchTitan, iterating on the UX, and upstreaming it later when it is more general.

I think we could put it in a separate branch for now. Also, using HTML might be too burdensome; since most of the info is structured, can we simply put it in a JSON file? That would also simplify the landing process.
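
For what it's worth, if the captured records were plain dataclasses (as in the hypothetical sketch earlier), a JSON dump is only a few lines, and an HTML viewer could be layered on the same file later:

```python
# Sketch: serialize the captured records to JSON; names reuse the hypothetical
# capture_sharding / ModuleSharding sketch above.
import json
from dataclasses import asdict


def dump_sharding_json(records, path="sharding_debug.json"):
    with open(path, "w") as f:
        json.dump([asdict(r) for r in records], f, indent=2)


# Usage:
# with capture_sharding(model) as records:
#     model(batch).sum().backward()
# dump_sharding_json(records)
```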

@fegin (Contributor, Author) commented Feb 10, 2026

@wwwjn Let's put it in another branch for now. But I'm not sure I agree on the HTML part. HTML is more readable; even though this PR also generates an ASCII file, I merely use that for programmatic verification, and mostly I read the HTML. tlparse also uses HTML. Let's iterate in another branch to figure out the best way to upstream the tool.

@anshul-si

@fegin I'm unsure how much work it would be to integrate this into CommDebugMode.
[screenshot: CommDebugMode output]

As you can see, CommDebugMode does include some of the information you're already trying to output. Furthermore, CommDebugMode uses noise levels to control how much information is output. In this case, you could either create a different output function just for your output, or make your work the minimum noise level, where no ops are shown and only module sharding information appears. In addition, there technically was an HTML output [screenshot omitted] you could use, but I don't think it was ever landed, though Claude should be able to do this part pretty easily.
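
For reference, typical CommDebugMode usage looks roughly like this; the noise_level argument and method name are from my recollection of the PyTorch API and may differ across versions, and `model`/`batch` are placeholders:

```python
# Rough CommDebugMode usage for module-level communication/sharding info.
from torch.distributed.tensor.debug import CommDebugMode

comm_mode = CommDebugMode()
with comm_mode:
    model(batch).sum().backward()

# Lower noise levels summarize at module granularity; higher levels include aten ops.
print(comm_mode.generate_comm_debug_tracing_table(noise_level=1))
```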

@fegin (Contributor, Author) commented Feb 11, 2026

@anshul-si Yes, I saw that before :) But the main information I need is simply the input, state, and output sharding of a module, and there I only see state sharding. One could try to infer the rest from the aten op order, but it is better to capture this information explicitly.

I think one way forward is to enhance CommDebugMode to capture input and output sharding. The rest would just be UX improvements, like an option to filter out aten ops, an option to fold repeated modules, etc.


Labels

ciflow/8gpu, CLA Signed